from gd import *
from .helpers import np_expit, softplus_thresholded, hardplus
import autograd.numpy as np
from .ContinuousModel import ContinuousModel
import random


class ContinuousModel2View:
    """Two-view wrapper that couples two ``ContinuousModel`` instances.

    Both per-view models receive the same latent ``z`` values (see
    ``set_latents``) while keeping their own ``mu`` and ``sigma``.  All
    free parameters are packed into one flat vector, laid out as::

        [ z (n_latent_factors * N)
        | mu view 1 (n_latent_factors * D1)
        | mu view 2 (n_latent_factors * D2)
        | sigma view 1 | sigma view 2 ]

    which is the layout ``configure`` builds and ``unpack_params`` reads.
    """

    def __init__(self, view1, view2, model1_weight, model2_weight, num_subtrees=0, subtree_loadings=None, max_depth=3, seed=None):
        """Store the two views and build one ``ContinuousModel`` per view.

        Args:
            view1: 2-D data for the first view; ``shape[0]`` is taken as N
                and ``shape[1]`` as D1.
            view2: 2-D data for the second view; ``shape[1]`` is taken as
                D2.  (Assumed to have the same number of rows as ``view1``
                since a single N is used for both — not checked here.)
            model1_weight: multiplier for view 1's terms in ``objective``.
            model2_weight: multiplier for view 2's terms in ``objective``.
            num_subtrees: forwarded to both ``ContinuousModel`` instances.
            subtree_loadings: two-element sequence; element 0 is forwarded
                to the view-1 model and element 1 to the view-2 model.
                ``None`` (the default) is treated as ``[None, None]``.
            max_depth: forwarded to both ``ContinuousModel`` instances.
            seed: seeds both the NumPy and the stdlib ``random`` global
                generators, and is forwarded to both sub-models.
        """
        # A list literal as a default would be shared between every call
        # (and stored on the instance), so build a fresh one per instance.
        if subtree_loadings is None:
            subtree_loadings = [None, None]

        self.view1 = view1
        self.view2 = view2
        self.num_subtrees = num_subtrees
        self.subtree_loadings = subtree_loadings
        self.max_depth = max_depth
        self.N = view1.shape[0]
        self.D1 = view1.shape[1]
        self.D2 = view2.shape[1]

        # NOTE: this reseeds the process-wide generators, not private ones.
        self.seed = seed
        np.random.seed(self.seed)
        random.seed(self.seed)

        self.iter_counter = 0
        self.prev_params = None

        # NOTE(review): `sparse`, `greedy` and `inference` are not defined
        # anywhere in this file.  Unless `from gd import *` provides them,
        # these two lines raise NameError — confirm where they come from.
        self.model1 = ContinuousModel(view1, num_subtrees, subtree_loadings[0], sparse, greedy, max_depth, inference, seed)
        self.model2 = ContinuousModel(view2, num_subtrees, subtree_loadings[1], sparse, greedy, max_depth, inference, seed)

        self.model1.sigma_scale = 0.1
        self.model2.sigma_scale = 0.1

        self.model1_weight = model1_weight
        self.model2_weight = model2_weight

        # When True, `unpack_params` passes the mus through
        # `softplus_thresholded`; the flag is mirrored onto both sub-models.
        self.nn = True
        self.model1.nn = self.nn
        self.model2.nn = self.nn

    def configure(self):
        """Configure both sub-models and build the initial flat ``params``.

        The number of latent factors is taken from the view-1 model.
        """
        self.model1.configure()
        self.model2.configure()

        self.n_latent_factors = self.model1.n_latent_factors

        # Raw z parameters start at 0.5 for every (factor, sample) pair;
        # `unpack_params` maps them through `np_expit` before use.
        latent_z = np.tile(
            np.array(0.5 * np.ones(self.N)), self.n_latent_factors)

        # could also make this random
        latent_sigma_view1 = np.array([0.1])
        latent_sigma_view2 = np.array([0.1])

        self.params = [latent_z, self.model1.latent_mu, self.model2.latent_mu, latent_sigma_view1, latent_sigma_view2]
        self.params = np.hstack(self.params)

    def set_sigmas(self, sigma_view1, sigma_view2):
        """Assign ``latent_sigma`` on the view-1 and view-2 models."""
        self.model1.latent_sigma = sigma_view1
        self.model2.latent_sigma = sigma_view2

    def set_latents(self, mus_view1, mus_view2, zs, sigma_view1, sigma_view2):
        """Push latents into both sub-models; ``zs`` is shared by both."""
        self.model1.set_latents(mus_view1, zs, sigma_view1)
        self.model2.set_latents(mus_view2, zs, sigma_view2)

    def set_all_fixed(self):
        """Mark every entry of ``self.factors`` as fixed.

        NOTE(review): ``self.factors`` is never assigned in this class, so
        this raises AttributeError unless a caller sets it first — confirm
        whether this should delegate to the two sub-models instead.
        """
        for factor in self.factors:
            factor.fixed = True

    def unpack_params(self, params):
        """Split a flat parameter vector into its named pieces.

        Args:
            params: flat vector in the layout described on the class.

        Returns:
            Tuple ``(mus_view1, mus_view2, zs, sigma_view1, sigma_view2)``.
            ``zs`` has been passed through ``np_expit``; when ``self.nn``
            is true the mus have been passed through
            ``softplus_thresholded``.  The sigmas are the last two entries
            of ``params``, returned untransformed.
        """
        n_samples = self.N
        n_factors = self.n_latent_factors
        dim1 = self.D1
        dim2 = self.D2

        # Start offsets of the two mu sections inside `params`.
        mu1_start = n_factors * n_samples
        mu2_start = n_factors * n_samples + n_factors * dim1

        latent_zs = [
            np_expit(params[k * n_samples:(k + 1) * n_samples])
            for k in range(n_factors)
        ]
        latent_mus_view1 = [
            params[mu1_start + k * dim1:mu1_start + (k + 1) * dim1]
            for k in range(n_factors)
        ]
        latent_mus_view2 = [
            params[mu2_start + k * dim2:mu2_start + (k + 1) * dim2]
            for k in range(n_factors)
        ]

        latent_sigma_view1 = params[-2]
        latent_sigma_view2 = params[-1]

        mus_view1 = np.array(latent_mus_view1)
        mus_view2 = np.array(latent_mus_view2)
        if self.nn:
            mus_view1 = softplus_thresholded(mus_view1)
            mus_view2 = softplus_thresholded(mus_view2)

        return mus_view1, mus_view2, np.array(latent_zs), latent_sigma_view1, latent_sigma_view2

    def objective(self, params, t):
        """Return the negated, weighted two-view objective for ``params``.

        Side effects: writes the unpacked latents into both sub-models and
        prints diagnostic values on every call.

        Args:
            params: flat parameter vector (see ``unpack_params``).
            t: unused here; presumably the optimizer's iteration index.
        """
        latent_mus_view1, latent_mus_view2, latent_zs, latent_sigma_view1, latent_sigma_view2 = self.unpack_params(params)
        self.set_latents(latent_mus_view1, latent_mus_view2, latent_zs, latent_sigma_view1, latent_sigma_view2)

        view1_gaussian = np.sum(self.model1.gaussian_likelihood())
        view2_gaussian = np.sum(self.model2.gaussian_likelihood())
        # z is shared, so its likelihood is taken from model1 only.
        z_likelihood = self.model1.z_likelihood()

        model1_objective = view1_gaussian + z_likelihood + self.model1.mu_prior() + self.model1.sigma_prior()
        model2_objective = view2_gaussian + z_likelihood + self.model2.mu_prior() + self.model2.sigma_prior()

        print('VIEW 1 Gaussian:', view1_gaussian)
        print('VIEW 2 Gaussian:', view2_gaussian)
        print('SIGMA1', self.model1.latent_sigma)
        # Was labelled 'SIGMA1' as well, which made the two lines
        # indistinguishable in the log.
        print('SIGMA2', self.model2.latent_sigma)
        print('VIEW 1 OBJECTIVE:', model1_objective)
        print('VIEW 2 OBJECTIVE:', model2_objective)

        # only account for the prior on z once in the full model
        return -1 * (self.model1_weight * (model1_objective - z_likelihood) +
                     self.model2_weight * (model2_objective - z_likelihood) +
                     z_likelihood)

    def callback(self, params, t, g):
        """Per-iteration hook: sync sub-models and bump ``iter_counter``.

        Args:
            params: current flat parameter vector.
            t: unused here; presumably the optimizer's iteration index.
            g: unused here; presumably the current gradient.
        """
        mus_view1, mus_view2, zs, sigma_view1, sigma_view2 = self.unpack_params(params)
        self.set_latents(mus_view1, mus_view2, zs, sigma_view1, sigma_view2)
        self.iter_counter += 1

        self.prev_params = params

    def update(self, x):
        """Constrain the two trailing sigma entries of ``x`` in place.

        Each of the last two entries is passed through ``hardplus``
        (presumably a clamp at zero — confirm in ``.helpers``); a result of
        exactly 0 is replaced by 0.001 so sigma never becomes zero.

        Returns:
            The same ``x`` object, modified in place.
        """
        x[-1] = hardplus(x[-1])
        x[-2] = hardplus(x[-2])
        if x[-1] == 0:
            x[-1] = 0.001
        if x[-2] == 0:
            x[-2] = 0.001
        return x
